{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import tenseal as ts\n",
    "import numpy as np\n",
    "from skimage.util.shape import view_as_windows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def almost_equal(vec1, vec2, m_pow_ten):\n",
    "    \"\"\"Check that two sequences agree element-wise within 10 ** -m_pow_ten.\n",
    "\n",
    "    Args:\n",
    "        vec1: first sequence of numbers.\n",
    "        vec2: second sequence of numbers.\n",
    "        m_pow_ten: the allowed absolute difference per element is\n",
    "            10 ** -m_pow_ten (e.g. 3 allows a difference of up to 0.001).\n",
    "\n",
    "    Returns:\n",
    "        True when both sequences have the same length and every pair of\n",
    "        elements differs by at most the tolerance, False otherwise.\n",
    "    \"\"\"\n",
    "    if len(vec1) != len(vec2):\n",
    "        return False\n",
    "\n",
    "    upper_bound = pow(10, -m_pow_ten)\n",
    "    # `<=` (rather than `not >`) makes NaN differences count as a mismatch:\n",
    "    # every comparison with NaN is False, so `abs(nan) > bound` would let\n",
    "    # invalid values slip through as \"equal\".\n",
    "    return all(abs(v1 - v2) <= upper_bound for v1, v2 in zip(vec1, vec2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Torch 2d convolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def torch_conv_2d(x, kernel, stride):\n",
    "    \"\"\"Reference 2D convolution computed with PyTorch.\n",
    "\n",
    "    Forwards `x` and `kernel` to `torch.nn.functional.conv2d` with the given\n",
    "    stride, no padding and no dilation, and returns its result.\n",
    "    \"\"\"\n",
    "    conv_options = {\"stride\": stride, \"padding\": 0, \"dilation\": 1}\n",
    "    return torch.nn.functional.conv2d(x, kernel, **conv_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input (4, 4)\n",
      "[[ 1  2  3  4]\n",
      " [ 5  6  7  8]\n",
      " [ 9 10 11 12]\n",
      " [13 14 15 16]]\n",
      "kernel (2, 2)\n",
      "[[1 2]\n",
      " [3 4]]\n"
     ]
    }
   ],
   "source": [
    "# side length of the square input image (x_size * x_size values)\n",
    "x_size = 4\n",
    "# side length of the square kernel (k_size * k_size values)\n",
    "k_size = 2\n",
    "# convolution stride\n",
    "stride = 1\n",
    "\n",
    "# deterministic data: consecutive integers 1, 2, ..., n^2 filled row by row\n",
    "x = (np.arange(x_size * x_size) + 1).reshape((x_size, x_size))\n",
    "kernel = (np.arange(k_size * k_size) + 1).reshape((k_size, k_size))\n",
    "\n",
    "# alternative: random values\n",
    "# x = np.random.randn(x_size, x_size)\n",
    "# kernel = np.random.randn(k_size, k_size)\n",
    "\n",
    "# show the shape and contents of both arrays\n",
    "for label, array in ((\"input\", x), (\"kernel\", kernel)):\n",
    "    print(label, array.shape)\n",
    "    print(array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TenSEAL context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the TenSEAL context for the CKKS scheme. 8192 is passed as the second\n",
    "# positional argument (presumably the polynomial modulus degree - confirm\n",
    "# against the TenSEAL documentation).\n",
    "ckks_coeff_mod_bit_sizes = [60, 40, 40, 60]\n",
    "context = ts.context(\n",
    "    ts.SCHEME_TYPE.CKKS,\n",
    "    8192,\n",
    "    coeff_mod_bit_sizes=ckks_coeff_mod_bit_sizes,\n",
    ")\n",
    "# global scale of the context: 2^40\n",
    "context.global_scale = 2 ** 40\n",
    "# Galois keys are generated so that ciphertext vectors can be rotated\n",
    "context.generate_galois_keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each convolution layer requires a round of communication between the client and the server. The server sends the ciphertext (encrypted vector) that is the input of the next convolution layer to the client, which decrypts it and applies im2col (Image Block to Column) to that input.\n",
    "\n",
    "After that, the client encodes and encrypts the input matrix in a vertical scan (column-major order) and sends it back to the server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "windows number:  9\n",
      "ckksvector size:  36\n",
      "9\n",
      "y_enc\n",
      "[44.00000566594027, 54.00000698409561, 64.00000865907336, 84.0000112607995, 94.0000125997811, 104.00001396028475, 124.00001662478137, 134.00001797146558, 144.00001931406726]\n",
      "CPU times: user 25.7 ms, sys: 0 ns, total: 25.7 ms\n",
      "Wall time: 24.6 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# encode the input with im2col and encrypt it under the context\n",
    "kernel_rows, kernel_cols = kernel.shape\n",
    "x_enc, windows_nb = ts.im2col_encoding(context, x, kernel_rows, kernel_cols, stride)\n",
    "\n",
    "print(\"windows number: \", windows_nb)\n",
    "print(\"ckksvector size: \", x_enc.size())\n",
    "\n",
    "# convolve the encrypted input with the plain kernel\n",
    "y_enc = x_enc.conv2d_im2col(kernel.tolist(), windows_nb)\n",
    "print(y_enc.size())\n",
    "\n",
    "# decrypt the result so it can be compared with the reference\n",
    "y_plain = y_enc.decrypt()\n",
    "print(\"y_enc\")\n",
    "print(y_plain)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare the result to torch conv2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y_toch\n",
      "[ 44.  54.  64.  84.  94. 104. 124. 134. 144.]\n"
     ]
    }
   ],
   "source": [
    "# reference result: add the leading batch and channel dimensions that\n",
    "# torch's conv2d takes, and convert the arrays to float32\n",
    "y_torch = torch_conv_2d(\n",
    "    torch.from_numpy(x.astype(\"float32\")).unsqueeze(0).unsqueeze(0),\n",
    "    torch.from_numpy(kernel.astype(\"float32\")).unsqueeze(0).unsqueeze(0),\n",
    "    stride\n",
    ")\n",
    "y_torch = y_torch.flatten().numpy()\n",
    "print(\"y_toch\")\n",
    "print(y_torch)\n",
    "\n",
    "# Compare with an absolute tolerance of 10 ** -3: the decrypted CKKS result is\n",
    "# only approximate, but a looser bound such as 10 ** 0 would accept values\n",
    "# that are off by a whole unit and could not catch a wrong convolution.\n",
    "assert almost_equal(y_plain, y_torch, 3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tenseal-venv",
   "language": "python",
   "name": "tenseal-venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
